from itertools import product
import math
import random
import time

import einops
import numpy as np
import pytest
from scipy.stats import norm
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F
from tests.helpers import (
    BOOLEAN_TUPLES,
    TRUE_FALSE,
    describe_dtype,
    get_test_dims,
    id_formatter,
)

torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
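# Number of repetitions used by the randomized tests below.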
k = 20


def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
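    # Counts how many elements of `a` and `b` are not within (rtol, atol) of each
    # other. Fails via torch.testing.assert_close only when the mismatch count
    # exceeds `count` and throw=True; otherwise it just returns the count.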
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    sumval = (~idx).sum().item()
    if sumval > count:
        if throw:
            print(f"Too many values not close: {sumval} mismatched, but only {count} allowed")
            torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

    return sumval


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Timer:
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            # Close the running interval, accumulate it, and start a new one.
            self.tock(name, evict=True, print_ms=False)
            self.tick(name)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")

        return self.agg.get(name, 0.0)

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

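    # estimate_quantiles with its defaults targets the 256 quantile midpoints
    # 1/512, 3/512, ..., 511/512; for U(0, 1) data the quantile values equal
    # these positions, which is what the linspace comparison below relies on.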
    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
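    # Blockwise quantization normalizes each `blocksize` chunk by its absmax;
    # nested=True additionally quantizes those per-block statistics.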
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs) / len(diffs)
    relerr = sum(reldiffs) / len(reldiffs)
    # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
    # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
    assert abserr < 0.011
    assert relerr < 0.018
    assert A2.dtype == dtype

    diffs = []
    code = F.create_dynamic_map(signed=signed)
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
    abserr = sum(diffs) / len(diffs)
    relerr = sum(reldiffs) / len(reldiffs)
    if signed:
        assert abserr < 0.0035
        assert relerr < 0.015
    else:
        assert abserr < 0.00175
        assert relerr < 0.012
    assert A2.dtype == dtype
    # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))


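# Reference (de)quantization helpers: symmetric int8 absmax quantization maps
# x -> round(x / max|x| * 127), so Ac @ Bc ≈ (127 / maxA) * (127 / maxB) * (A @ B)
# and mm_dequant rescales the int32 product by (maxA / 127) * (maxB / 127).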
def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()
    return minA, maxA


def mean(xx):
    return sum(xx) / float(len(xx))


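# Each entry maps a method name to a 5-tuple:
#   (quantize_A, quantize_B, dequantize_A, dequantize_B, mm_dequant)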
methods = {
    "linear": (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    ),
    "vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant),
}


@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    # print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    # print(mean(errors))
    # print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
        shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_close(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_close(out.float(), out2)


@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_close(out.float(), out2)


@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

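    # min_max performs asymmetric quantization: Ac = 127 * (A - minA - scale) / scale
    # with scale = (maxA - minA) / 2, i.e. A ≈ (scale / 127) * Ac + (minA + scale).
    # Recovering A @ B therefore needs the correction term sum(B, dim=0) * (minA + scale),
    # computed as `offset` in the loop below.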
    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_close(out.float(), out2.float())


@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
def test_int8_linear_matmul(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        C2 = F.int8_linear_matmul(A, B)
        torch.testing.assert_close(C1, C2.float())


@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims):
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, _, statsA, _, _ = F.int8_double_quant(A)
        CB, statsB, _ = F.int8_vectorwise_quant(B)
        output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)

        torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)


@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_dequant_mm(dim1, dim4, dims, has_bias):
    inner = 128
    bias = None
    if has_bias:
        bias = torch.randn(dim4, device="cuda", dtype=torch.float16)

    for i in range(1):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())
        if has_bias:
            C1 += bias

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        C2 = F.int8_linear_matmul(A1, B1)

        C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
        if has_bias:
            C4 += bias

        # TODO: is something wrong here? If so, the problem goes deeper
        # n = C1.numel()
        # p = 0.06
        std = C1.std(0).view(1, -1)
        C1 /= std
        C4 /= std
        # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
        # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.int8_mm_dequant(C2, maxA, maxB, bias=bias)
        C5 /= std
        torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
        n = C5.numel()
        assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))


@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
def test_colrow_absmax(dim1, dim2, dims, threshold):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        assert dims == 2

        row_stats1, _ = torch.abs(A.float()).max(1)
        col_stats1, _ = torch.abs(A.float()).max(0)

        if threshold > 0.0:
            A_truncated = A.clone()
            A_truncated[torch.abs(A_truncated) >= threshold] = 0.0
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)

            row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)

            nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten()
            nnz_block_ptr1 = torch.zeros(
                nnz_rows1_counts.shape[0] + 1,
                dtype=nnz_rows1_counts.dtype,
                device=nnz_rows1_counts.device,
            )
            nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

            torch.testing.assert_close(col_stats1_trunc, col_stats2)
            torch.testing.assert_close(row_stats1_trunc, row_stats2)
            # torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2)
        else:
            row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
            assert nnz_block_ptr2 is None
            torch.testing.assert_close(col_stats1, col_stats2)
            torch.testing.assert_close(row_stats1, row_stats2)


@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
def test_int8_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (~torch.isclose(CA, out_row1, atol=1)).sum().item()
        num_not_close_cols = (~torch.isclose(CAt, out_col1, atol=1)).sum().item()

        # allow up to a 1:500 mismatch rate due to rounding differences
        max_error_rate = 1 / 500
        assert num_not_close_cols <= max_error_rate * n, (
            f"Mismatch rate exceeded: {num_not_close_cols}/{n} column-wise elements differ ({num_not_close_cols / n:.4f})"
        )
        assert num_not_close_rows <= max_error_rate * n, (
            f"Mismatch rate exceeded: {num_not_close_rows}/{n} row-wise elements differ ({num_not_close_rows / n:.4f})"
        )

        torch.testing.assert_close(Srow.flatten().float(), statsA)
        torch.testing.assert_close(Scol.flatten().float(), statsAt)


@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    (
        pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
        for (dim1, dim4, inner) in zip(
            (1, 8, 2048, 4096),
            (2, 128, 2048, 4096),
            (4, 256, 512, 4096),
        )
    ),
)
def test_integrated_int8_linear_matmul(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, stats1a, _ = F.int8_vectorwise_quant(A)
        C2a, stats2a, _ = F.int8_vectorwise_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_close(maxA.flatten().float(), stats1a)
        torch.testing.assert_close(maxB.flatten().float(), stats2a)
        torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_close(C2a, B1, rtol=0, atol=1)

        out2 = F.int8_linear_matmul(A1, B1)

        out3 = F.vectorwise_mm_dequant(out2.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.025


@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    (
        pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
        for (dim1, dim4, inner) in zip(
            get_test_dims(1, 4 * 1024, n=6),
            get_test_dims(1, 4 * 1024, n=6),
            get_test_dims(1, 4 * 1024, n=6),
        )
    ),
)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale)
        # C3, S = F.nvidia_transform(outC32, "row", state=SC)
        C3 = outC32
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32 = F.int8_linear_matmul(A2, B2)
        out2 = F.int8_mm_dequant(outC32, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))


@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_double_quant(dim1, dim2):
    threshold = 2.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)

        if outlier_cols is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A) + A1
            torch.testing.assert_close(A1, A2)

            A[:, outlier_cols] = 0
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)


@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_int8_vectorwise_quant(dim1, dim2):
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)

        if outlier_cols is not None:
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            A[:, outlier_cols] = 0
            torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)


@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
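        # Pack the above-threshold entries into bitsandbytes' COO sparse format:
        # (n_rows, n_cols, nnz, row_indices, col_indices, values).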
        cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)


@pytest.mark.benchmark
def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B.t())
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    for _ in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, statsw1, _ = F.int8_vectorwise_quant(w1)
        CA, statsA, _ = F.int8_vectorwise_quant(A)

        out1_32 = F.int8_linear_matmul(CA, Cw1)
        out2 = F.int8_mm_dequant(out1_32, statsA, statsw1)

        # CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
        CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)

        out1_32 = F.int8_linear_matmul(CA, Cw1)
        out3 = F.int8_mm_dequant(out1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        # idx = torch.unique(coo_tensor._indices()[1]).long()
        # out4 = torch.matmul(A, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1


def test_matmuls():
    a = torch.randn(256, 512).half().cuda()
    b = torch.randn(256, 512).half().cuda()
    c1 = torch.matmul(a, b.t())
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul_cublas(a, b.t())

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
    print(err1, err2)


@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
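    # Mismatch budget: roughly 200 elements for this output shape (2048 x 12288*4).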
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_coo2csr():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_close(A2[idx], csrA.values)


def test_coo2csc():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_close(A2.t()[idx], cscA.values)


@pytest.mark.parametrize("dim1", [1 * 2048])
@pytest.mark.parametrize("dim2", [2048])
@pytest.mark.parametrize("dtype", [torch.int8])
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B)

    colidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, colidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)

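    # Reuse the mismatch rate from test_spmm_coo_very_sparse (~200 allowed per
    # 2048 * 12288 * 4 output elements), applied here to a smaller output.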
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, colidx], Bt.t()[colidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)


def test_zeropoint():
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

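    # quant_zp implements affine (zero-point) quantization: x_q = qx * x + zpx, so
    # x = (x_q - zpx) / qx. Since (qx * A + zp) @ B = qx * (A @ B) + zp * B.sum(0),
    # the C4/C6/C7 variants below recover A @ B by subtracting the zero-point
    # cross terms and dividing out the scales.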
    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    # print(ca.min(), ca.max())
    # print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    # print("")
    # print(C0.flatten()[:10])
    # print(C1.flatten()[:10])
    # print(C2.flatten()[:10])
    # print(C3.flatten()[:10])
    # print(C5.flatten()[:10])
    # print(C6.flatten()[:10])
    # print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)


@pytest.mark.deprecated
def test_extract_outliers():
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)


def test_blockwise_cpu_large():
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128]:  # , 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device="cpu")
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                reldiff = diff / torch.abs(A1 + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
            # print(sum(diffs)/len(diffs))
            # print(sum(reldiffs)/len(reldiffs))


def test_fp8_quant():
    for e_bits in range(1, 7):
        p_bits = 7 - e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(sum(abserr)/len(abserr))
        # print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.rand(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(sum(abserr)/len(abserr))
        # print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(3, sum(abserr)/len(abserr))
        # print(3, sum(relerr)/len(relerr))


def test_few_bit_quant():
    # print('')
    for bits in range(2, 9):
        # print('='*30, bits, '='*30)
        for method in ["linear", "fp8", "dynamic", "quantile"]:
            abserrs = []
            relerrs = []
            code = None
            if method == "linear":
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == "fp8":
                ebits = math.ceil(bits / 2)
                pbits = bits - ebits - 1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == "dynamic":
                code = F.create_dynamic_map(True, bits, bits).cuda()
            elif method == "quantile":
                values = torch.randn(2048, 2048, device="cuda")
                code = F.create_quantile_map(values, bits).cuda()
            # for some data types we have no zero
            # for some data types we have one zero
            # for some data types we have two zeros
            assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
            # print(method, (code==0).sum())
            assert code.numel() == 256
            for i in range(10):
                values = torch.randn(1, 32, device="cuda")
                values /= values.abs().max()
                # values[values.abs() < 1e-6] += 1e-5

                q1 = []
                v1 = []
                for v in values[0]:
                    idx = torch.abs(v - code).argmin()
                    q1.append(idx.item())
                    v1.append(code[idx].item())

                q1 = torch.Tensor(q1).cuda()
                v1 = torch.Tensor(v1).cuda()

                q2, S2 = F.quantize_blockwise(values, code=code)
                v2 = F.dequantize_blockwise(q2, S2)

                idx = torch.isclose(q1.int(), q2.int())
                err2 = torch.abs(v2 - values)
                abserrs.append(err2.mean().item())
                relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
                if idx.sum():
                    # some weird cases
                    err1 = torch.abs(v1 - values).mean()
                    # assert err2.mean() <= err1
                else:
                    torch.testing.assert_close(q1, q2)
            # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
    # assert False


def test_kbit_quantile_estimation():
    for i in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.038

    for i in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 4):
            total_values = 2**bits - 1
            # quantile positions at the bin midpoints: offset, ..., 1 - offset
            offset = 1 / (2 * total_values)
            p = np.linspace(offset, 1 - offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.035


@pytest.mark.benchmark
def test_bench_dequantization():
    a = torch.rand(1024, 1024, device="cuda").half()
    code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(a, code=code)
    print(qa.max())

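    # Theoretical best case in microseconds: 2 MiB of fp16 data moved at an
    # assumed ~672 GB/s of memory bandwidth.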
    max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
    # print(max_theoretical_mu)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        qa, SA = F.quantize_blockwise(a)
    torch.cuda.synchronize()
    # print((time.time()-t0)/1e6)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(dtype, quant_type, blocksize):
    vals = list(product([0, 1], repeat=4))

    code = {}
    for bits in vals:
        result = 0
        bias = 3
        sign, e1, e2, p1 = bits
        idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1
        sign = -1.0 if sign else 1.0
        exp = e1 * 2 + e2 * 1
        if exp == 0:
            # sub-normal
            if p1 == 0:
                result = 0
            else:
                result = sign * 0.0625
        else:
            # normal
            exp = 2 ** (-exp + bias + 1)
            frac = 1.5 if p1 else 1.0
            result = sign * exp * frac
        code[idx] = result
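
    # `code` reconstructs the expected FP4 value for each 4-bit index
    # (1 sign, 2 exponent, 1 fraction bit); it documents the format and is not
    # asserted against below.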

    A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
    qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
    A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)

    err = (A1 - A2).abs().float()
    relerr = (err / (A1.abs().float() + 1e-8)).mean()
    idx = err > 1.0
    err = err.mean()

    assert A2.dtype == dtype

    # With larger block sizes, we can expect this to blow up.
    # At blocksize>=1024, don't even bother looking at relerr.
    if blocksize <= 64:
        assert err.item() < 0.1
        assert relerr.item() < 0.28
    elif blocksize <= 256:
        assert err.item() < 0.11
        assert relerr.item() < 0.30
    elif blocksize <= 512:
        assert err.item() < 0.12
        assert relerr.item() < 0.31
    elif quant_type == "fp4":
        # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
        assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
    else:
        # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
        assert err.item() < math.log2(blocksize) * 8e-2


@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_compressed_stats(quant_type):
    for blocksize in [128, 64]:
        errs1 = []
        errs2 = []
        for i in range(10):
            A1 = torch.randn(1024, 1024, device="cuda").half()
            q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
            A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
            A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)

            err = (A1 - A2).abs().float()
            relerr = (err / (A1.abs().float() + 1e-15)).mean()
            err = err.mean()

            errs1.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

            err = (A1 - A3).abs().float()
            relerr = (err / (A1.abs().float() + 1e-15)).mean()
            err = err.mean()

            errs2.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

        # print(sum(errs1)/len(errs1), blocksize, quant_type)
        # print(sum(errs2)/len(errs2), blocksize, quant_type)


# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.benchmark
def test_bench_4bit_dequant(quant_type):
    blocksize = 256
    a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
    qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)

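    # Traffic estimate: the 4-bit input reads a.numel() / 2 bytes and the fp16
    # output writes a.numel() * 2 bytes.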
    input_size = a.numel() / 2
    output_size = a.numel() * 2
    num_bytes = input_size + output_size
    GB = num_bytes / 1e9
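    # Best-case dequantization time assuming ~768 GB/s of memory bandwidth.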
    max_theoretical_s = GB / 768
    # print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024 * 12, device="cuda").half()

    iters = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
        # b.copy_(a)
    torch.cuda.synchronize()
    # print((time.time()-t0)/iters*1e6)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    torch.matmul(b, a.t())
    # torch.cuda.synchronize()
    # print((time.time()-t0)/iters*1e6)


def test_normal_map_tree():
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    num_pivots = 1
    # print(values)
    while num_pivots < 16:
        idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots))
        # print(idx)
        num_pivots *= 2
        pivots = []
        for i in idx:
            pivots.append((values[i - 1] + values[i]) / 2)
        # print(pivots)


@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
    "quant_storage",
    [torch.uint8, torch.float16, torch.bfloat16, torch.float32],
    ids=describe_dtype,
)
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
    for dim in [128, 256, 512, 1024]:
        # for dim in [4*1024]:
        # for dim in [1*16]:
        errs1 = []
        errs2 = []
        errs3 = []
        relerrs1 = []
        relerrs2 = []
        relerrs3 = []
        max_errs1 = []
        max_errs2 = []
        max_errs3 = []

        for i in range(100):
            if kind == "fc1":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "fc2":
                A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
                B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "attn":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "attn_packed":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)

            qB, state = F.quantize_4bit(
                B,
                quant_type=storage_type,
                compress_statistics=double_quant,
                quant_storage=quant_storage,
            )
            C3 = torch.matmul(A, B.t())
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)

            err1 = (C1 - C2).abs().float()
            err2 = (C3 - C2).abs().float()
            err3 = (C3 - C1).abs().float()

            mag1 = torch.abs(C1).float() + 1e-5
            mag2 = torch.abs(C3).float() + 1e-5
            mag3 = torch.abs(C3).float() + 1e-5

            relerr1 = err1 / mag1
            relerr2 = err2 / mag2
            relerr3 = err3 / mag3

            max_err1 = err1.max()
            max_err2 = err2.max()
            max_err3 = err3.max()

            errs1.append(err1.mean().item())
            errs2.append(err2.mean().item())
            errs3.append(err3.mean().item())

            relerrs1.append(relerr1.mean().item())
            relerrs2.append(relerr2.mean().item())
            relerrs3.append(relerr3.mean().item())

            max_errs1.append(max_err1.item())
            max_errs2.append(max_err2.item())
            max_errs3.append(max_err3.item())

            # Record the mismatch count against a tight tolerance without failing.
            c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False)
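        # Normalize accumulated errors by sqrt(dim) so the thresholds below are
        # comparable across the tested dimensions.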
        err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
        err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
        err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
        relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
        relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
        relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
        maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
        maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
        maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
        absratio = err2 / err3
        relratio = relerr2 / relerr3
        maxratio = maxerr2 / maxerr3

        # for debugging if the tests fails
        #
        # print('='*80)
        # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
        # print(C1.flatten()[-20:])
        # print(C2.flatten()[-20:])
        # print(f'inference vs training abs: {err1}')
        # print(f'inference vs training rel: {relerr1}')
        # print(f'inference vs training max: {maxerr1}')
        # print(f'inference vs training vs torch err ratio abs: {absratio}')
        # print(f'inference vs training vs torch err ratio rel: {relratio}')
        # print(f'inference vs training vs torch err ratio max: {maxratio}')
        if dtype == torch.float16:
            if dim <= 512:
                assert err1 < 7e-5
                assert relerr1 < 0.0008
            else:
                assert err1 < 6e-5
                assert relerr1 < 2e-4
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.float32:
            if dim <= 512:
                assert err1 < 5e-8
                assert relerr1 < 1e-6
                assert maxerr1 < 1e-7
            else:
                assert err1 < 5e-8
                assert relerr1 < 8e-6
                assert maxerr1 < 1e-7
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.bfloat16:
            if dim <= 512:
                assert err1 < 6e-4
                assert relerr1 < 0.007
                assert maxerr1 < 0.015
            else:
                assert err1 < 2e-4
                assert relerr1 < 0.002
                assert maxerr1 < 0.0012
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.04 and relratio > 0.96
            assert maxratio < 1.02 and maxratio > 0.98


@pytest.mark.skip("Paged (managed) memory test is currently disabled")
def test_managed():
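    # Paged tensors are backed by CUDA unified (managed) memory, so they can be
    # touched from both host and device; page_deviceid records the GPU the
    # pages are mapped to.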
    n = 32 * 10
    A = F.get_paged(n, n, dtype=torch.float32)
    B = F.get_paged(n, n, dtype=torch.uint8)
    B2 = F.get_paged(n, n, dtype=torch.float32)
    assert A.is_paged
    assert B.is_paged
    assert A.page_deviceid == 0
    assert B.page_deviceid == 0
    F.fill(A, 17.0)
    F.fill(B, 17)
    F.fill(B2, 2)
    assert (A == 17).sum().item() == n * n
    assert (B == 17).sum().item() == n * n
    C = A * B.float()
    assert (C == 289).sum().item() == n * n
    F._mul(A, B2)
    F._mul(A, B2)
    F._mul(A, B2)
    assert (A == 17 * (2**3)).sum().item() == n * n


@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    n_dims = 10
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = get_test_dims(0, 8192, n=n_dims)
    dims = [dim + (64 - (dim % 64)) for dim in dims]  # pad each dim to a multiple of 64
    # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
        B = torch.eye(dim, dtype=dtype, device="cuda")
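        # Both the nf4 and fp4 codebooks contain 0 and 1 exactly (and each
        # block's absmax is 1 here), so quantizing the identity should be
        # lossless and the exact-equality asserts below should hold.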

        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
        C3 = torch.matmul(A, B.t())
        C2 = bnb.matmul_4bit(A, qB.t(), state)
        A.requires_grad = True
        C1 = bnb.matmul_4bit(A, qB.t(), state)

        torch.testing.assert_close(A, C3)
        torch.testing.assert_close(A, C1)
        torch.testing.assert_close(A, C2)
        # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
        # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)


@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3"))
@pytest.mark.deprecated
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        n = A1.numel()
        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))
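

# Illustrative sketch (not part of the original tests): a minimal pure-PyTorch
# take on per-vector absmax int8 quantization, assuming F.vectorwise_quant
# scales by the absmax along the given dim and rounds into int8. The helper
# names are ours, not a bitsandbytes API.
def _reference_vectorwise_quant(A, dim=0):
    absmax = A.abs().amax(dim=dim, keepdim=True)
    qA = torch.round(A / absmax * 127).to(torch.int8)
    return qA, absmax


def _reference_vectorwise_dequant(qA, absmax):
    return qA.float() / 127 * absmax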


@pytest.mark.deprecated
def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001
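

# Illustrative sketch (not part of the original tests): nearest-code lookup as
# we understand quantize_no_absmax — map each value to the index of the closest
# entry in the sorted 256-entry code, then dequantize by lookup. The helper
# name is ours.
def _reference_quantile_roundtrip(A, code):
    flat = A.flatten().contiguous()
    idx = torch.searchsorted(code, flat).clamp_max_(code.numel() - 1)
    lo = (idx - 1).clamp_min_(0)
    # pick whichever neighbor of the insertion point is closer
    pick_lo = (flat - code[lo]).abs() <= (code[idx] - flat).abs()
    idx = torch.where(pick_lo, lo, idx)
    return code[idx].view_as(A)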


@pytest.mark.deprecated
def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    print("mean abs err:", sum(diffs) / len(diffs))
    print("mean rel err:", sum(reldiffs) / len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004
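

# Illustrative sketch (not part of the original tests): to our understanding,
# F.quantize composes a global absmax rescale with a lookup into the dynamic
# 8-bit code from F.create_dynamic_map. The helper name is ours.
def _reference_dynamic_quant(A):
    code = F.create_dynamic_map(signed=True).to(A.device)
    absmax = A.abs().max()
    C = F.quantize_no_absmax(A / absmax, code)
    return C, (absmax, code)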


@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
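    # gnorm_vec2 is the kernel's history buffer and stores *squared* norms,
    # while gnorm_vec1 replicates it on the Python side with unsquared norms;
    # hence the sqrt in the comparison at the bottom of the loop.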
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)


@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
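    # Checks that F.transform produces the same tiled layout (and the same
    # shape metadata) as F.nvidia_transform for the col32/col_turing/col_ampere
    # formats used by the int8 matmul path.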
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
        elif dims == 3:
            A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_close(out1, out2)


@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose"))
@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims"))
@pytest.mark.deprecated
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    try:
        func = F.get_transform_func(dtype, orderA, orderOut, transpose)
    except ValueError as ve:
        pytest.skip(str(ve))  # skip if not supported

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_close(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_close(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32
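                # offset/col2/row2 feed the layout assertions that are
                # currently commented out below; only the trivial flatten
                # identity is checked.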

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
        torch.testing.assert_close(A, out2)
